#!/usr/bin/env python

import sys
import os
import time
import signal
import random
import traceback
import errno
import multiprocessing
from optparse import OptionParser

from multiprocessing import SimpleQueue, Process, Event

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, 'common')))
import utils

r = utils.import_python_driver()


def call_ignore_interrupt(fun):
    """Invoke ``fun()`` and return its result, retrying while it is interrupted.

    An ``IOError`` whose ``errno`` is ``EINTR`` triggers another attempt; any
    other ``IOError`` (and every other exception type) propagates to the caller.
    """
    while True:
        try:
            outcome = fun()
        except IOError as err:
            if err.errno == errno.EINTR:
                # Interrupted by a signal: simply try again.
                continue
            raise
        else:
            return outcome


def stress_client_proc(options, start_event, exit_event, stat_queue, host_offset, random_seed):
    """Body of one stress client process.

    Selects a ``(host, port)`` pair from ``options["hosts"]`` using
    ``host_offset`` modulo the number of hosts, seeds the ``random`` module
    with ``random_seed``, puts ``"ready"`` on ``stat_queue``, waits for
    ``start_event`` and then issues queries through a ``QueryThrottler`` until
    ``exit_event`` is set.

    When ``options["ops_per_conn"]`` is non-zero the connection is re-opened
    after that many operations; the first connection starts from a random
    operation count. Any exception raised while connecting or querying is
    reported on ``stat_queue`` as a zero-latency result carrying the error
    text, and the loop then reconnects.
    """
    # Shutdown is driven by exit_event (and terminate() from the parent).
    # Without this, a terminal Ctrl-C - which is delivered to the whole
    # foreground process group - raises KeyboardInterrupt here, which the
    # `except Exception` below does not catch, so the worker would die with a
    # traceback and a non-zero exit code.
    try:
        signal.signal(signal.SIGINT, signal.SIG_IGN)
    except ValueError:
        # signal.signal() is only permitted in the main thread; keep working
        # with the inherited handler if this function is run elsewhere.
        pass

    host_offset = host_offset % len(options["hosts"])
    host, port = options["hosts"][host_offset]
    ops_per_conn = options["ops_per_conn"]

    random.seed(random_seed)

    # Start part-way through the first connection's quota.
    ops_done = int(random.random() * ops_per_conn)
    # ops_per_conn == 0 means "never recycle the connection".
    loop_cond = (lambda: True) if ops_per_conn == 0 else (lambda: ops_done < ops_per_conn)

    runner = QueryThrottler(options, stat_queue)
    stat_queue.put("ready")
    start_event.wait()

    while not exit_event.is_set():
        try:
            with r.connect(host, port) as conn:
                while loop_cond() and not exit_event.is_set():
                    runner.send_query(conn)
                    ops_done += 1
                ops_done = 0
        except Exception as ex:
            stat_queue.put({"timestamp": time.time(), "latency": 0.0, "errors": [str(ex)]})


def spawn_clients(options, start_event, exit_event, stat_queue):
    """Start ``options["clients"]`` client processes and wait until all are ready.

    The ``random`` module is seeded with ``options["seed"]`` (a fresh random
    value when that is ``None``), the seed is reported on stderr, and every
    client receives its own seed drawn from that generator. Client ``i`` is
    started with host offset ``i``.

    Returns the list of started ``Process`` objects.

    Raises:
        RuntimeError: if a client puts anything other than ``"ready"`` on
            ``stat_queue`` during start-up.
    """
    num_clients = options["clients"]

    random_seed = options["seed"]
    if random_seed is None:
        random_seed = random.random()

    # Report the seed with repr() so the printed value round-trips to exactly
    # the float used below; a fixed 6-decimal rendering would not reproduce
    # the run when fed back through --seed.
    print(f"Random seed used: {random_seed!r}", file=sys.stderr)
    random.seed(random_seed)

    client_procs = []
    for host_offset in range(num_clients):
        proc = Process(target=stress_client_proc, args=(
            options,
            start_event,
            exit_event,
            stat_queue,
            host_offset,
            random.random(),
        ))
        proc.start()
        client_procs.append(proc)

    # Each client announces itself exactly once before waiting on start_event.
    for _ in range(num_clients):
        response = call_ignore_interrupt(stat_queue.get)
        if response != "ready":
            raise RuntimeError(f"Unexpected response from client: {response}")

    return client_procs


def stop_clients(exit_event, child_procs, timeout):
    """Signal shutdown, terminate every child and wait for them to exit.

    Sets ``exit_event``, calls ``terminate()`` on each process and then polls
    every 0.1 seconds for up to ``timeout`` seconds until none is alive.

    A child counts as failed when it is still alive afterwards or when its
    exit code is neither ``0`` nor ``-15``. If any child failed, a summary is
    written to stderr and the program exits with status 1; otherwise the
    current time is returned.
    """
    exit_event.set()
    for child in child_procs:
        child.terminate()

    wait_began = time.time()
    while time.time() - wait_began < timeout:
        if not any(child.is_alive() for child in child_procs):
            break
        time.sleep(0.1)

    problems = []
    for number, child in enumerate(child_procs, start=1):
        if not (child.is_alive() or child.exitcode not in [0, -15]):
            continue
        # A missing/zero exit code here means the child never went away.
        if child.exitcode:
            problems.append(f"process {number} failed with code {child.exitcode}")
        else:
            problems.append(f"process {number} timed out")

    if problems:
        print(f"Sub-processes failed: {', '.join(problems)}", file=sys.stderr)
        sys.exit(1)
    return time.time()


def print_stats(stats, start_time, end_time, num_clients):
    """Write a summary of the run to stderr.

    ``stats`` holds the operation ``"count"``, the summed ``"latency"`` and an
    ``"errors"`` mapping of message to occurrence count. The report contains
    the run duration, the share of total client time spent inside queries, a
    two-row table (totals, throughput, per-client throughput, average
    latency) and, if any were recorded, the error tally.
    """
    def emit(line=""):
        print(line, file=sys.stderr)

    duration = end_time - start_time
    emit(f"Duration: {duration:0.3f} seconds")
    emit()
    emit("Operations data: ")

    # A zero-length run has no meaningful rate; report it as infinite.
    if duration == 0.0:
        per_sec = float('inf')
        per_sec_per_client = float('inf')
    else:
        per_sec = stats["count"] / duration
        per_sec_per_client = per_sec / num_clients

    if stats["count"] != 0:
        avg_latency = stats["latency"] / stats["count"]
    else:
        avg_latency = 0.0

    header = ("total", "per sec", "per sec client avg", "avg latency")
    values = (
        str(stats["count"]),
        f"{per_sec:0.3f}",
        f"{per_sec_per_client:0.3f}",
        f"{avg_latency:0.6f}",
    )
    # Each column is as wide as its longest cell plus two characters.
    widths = [max(len(title), len(value)) + 2 for title, value in zip(header, values)]

    total_client_time = duration * num_clients
    if total_client_time:
        percent_time = 100 * stats["latency"] / total_client_time
    else:
        percent_time = 0.0
    emit(f"Percent time clients spent in ReQL space: {percent_time:.2f}")

    # First column is left-aligned, the remaining ones right-aligned.
    for row in (header, values):
        cells = [row[0].ljust(widths[0])]
        cells.extend(cell.rjust(width) for cell, width in zip(row[1:], widths[1:]))
        emit("".join(cells))

    error_counts = stats["errors"]
    if error_counts:
        emit()
        emit("Errors encountered:")
        for message, occurrences in error_counts.items():
            emit(f"{message}: {occurrences}")


def interrupt_handler(sig, frame, exit_event, parent_pid):
    """Signal handler that requests shutdown, but only in the parent process.

    ``sig`` and ``frame`` are the standard handler arguments and are unused.
    ``exit_event`` is set only when the current process id equals
    ``parent_pid``; in any other process the call does nothing.
    """
    if os.getpid() != parent_pid:
        return
    exit_event.set()


class QueryThrottler:
    """Runs workload queries at a bounded rate and reports each result.

    The per-client interval between queries is ``clients / ops_per_sec``
    seconds; an ``ops_per_sec`` of ``0`` disables pacing altogether.
    """

    def __init__(self, options, stat_queue):
        """Read the workload and rate settings from ``options``."""
        self.workload = options["workload"]
        self.stat_queue = stat_queue
        rate = options["ops_per_sec"]
        self.secs_per_op = float(options["clients"]) / rate if rate != 0 else 0
        self.next_query_time = None

    def _pace(self):
        """Sleep as needed so queries stay on the configured schedule."""
        interval = self.secs_per_op

        if self.next_query_time is None:
            # First query: start at a random point within one interval.
            time.sleep(random.random() * interval)
            self.next_query_time = time.time()

        current = time.time()
        lag = current - self.next_query_time
        max_lag = 10 * interval

        if lag > max_lag:
            # Too far behind: cap the backlog at ten intervals.
            self.next_query_time = current - max_lag
            return

        self.next_query_time += interval
        if lag < 0:
            time.sleep(-lag)

    def send_query(self, conn):
        """Run one workload query on ``conn`` and put its result on the queue.

        The result dict starts with a ``"timestamp"``, is updated with
        whatever ``workload.run(conn)`` returns, and finally gets a
        ``"latency"`` entry. Driver errors and interrupted system calls are
        recorded under ``"errors"``; other I/O errors are re-raised.
        """
        if self.secs_per_op:
            self._pace()

        began = time.time()
        outcome = {"timestamp": began}
        try:
            outcome.update(self.workload.run(conn))
        except (r.ReqlError, r.ReqlDriverError) as err:
            outcome.setdefault("errors", []).append(str(err))
        except (IOError, OSError) as err:
            if err.errno != errno.EINTR:
                raise
            outcome.setdefault("errors", []).append("Interrupted system call")
        outcome["latency"] = time.time() - began

        self.stat_queue.put(outcome)


def _record_stat(stats, stat):
    """Fold one client result into the running totals.

    Increments ``stats["count"]``, adds ``stat["latency"]`` to
    ``stats["latency"]`` and bumps the per-message counters in
    ``stats["errors"]``. Returns the list of error messages carried by
    ``stat`` (empty when it has none).
    """
    stats["count"] += 1
    stats["latency"] += stat["latency"]
    errors = stat.get("errors", [])
    for error in errors:
        stats["errors"][error] = stats["errors"].get(error, 0) + 1
    return errors


def stress_controller(options):
    """Run the stress test described by ``options`` and print a summary.

    Spawns the client processes, releases them via the start event and then
    consumes their results from the stat queue until one of the following
    happens: SIGINT is received in this process, ``options["op_count"]``
    results have been collected, or ``options["duration"]`` seconds have
    elapsed (a value of ``0`` disables the respective limit).

    Unless ``options["quiet"]`` is set, each result's timestamp and latency is
    printed to stdout; error messages are printed unless
    ``options["ignore_errors"]`` is set. On the way out the clients are
    stopped, results still in the queue are counted, and the statistics are
    printed. An unexpected exception in the collection loop is printed and
    turns into exit status 1 after that cleanup.
    """
    stat_queue = SimpleQueue()
    start_event = Event()
    exit_event = Event()
    child_procs = spawn_clients(options, start_event, exit_event, stat_queue)
    stats = {"count": 0, "latency": 0.0, "errors": {}}
    start_time = time.time()

    parent_pid = os.getpid()
    signal.signal(signal.SIGINT, lambda sig, frame: interrupt_handler(sig, frame, exit_event, parent_pid))

    try:
        start_event.set()

        while not exit_event.is_set():
            # Check the time limit on every pass, not only when the queue is
            # idle: under sustained load the queue may never be observed
            # empty, and the run would otherwise overshoot its duration.
            if options["duration"] and (time.time() - start_time > options["duration"]):
                exit_event.set()
                break

            if stat_queue.empty():
                time.sleep(0.1)
                continue

            stat = call_ignore_interrupt(stat_queue.get)
            if not options["quiet"]:
                print(f"{stat['timestamp']:0.3f}: {stat['latency']:0.6f}")
                sys.stdout.flush()

            errors = _record_stat(stats, stat)
            if not options["ignore_errors"]:
                for error in errors:
                    print(f"  - {error}")

            if options["op_count"] and stats["count"] >= options["op_count"]:
                exit_event.set()
    except Exception:
        traceback.print_exc()
        sys.exit(1)
    finally:
        # Give clients at least two seconds, or three query intervals at the
        # configured overall rate if that is longer, to go away.
        stop_timeout = max(2.0, 3.0 / (options["ops_per_sec"] or 1))
        end_time = stop_clients(exit_event, child_procs, stop_timeout)

        # Count whatever the clients reported before they were stopped.
        while not stat_queue.empty():
            _record_stat(stats, call_ignore_interrupt(stat_queue.get))

        print_stats(stats, start_time, end_time, options["clients"])


if __name__ == "__main__":
    # Command-line entry point: parse the options, make sure the target
    # database and table exist, load the requested workload module and hand
    # everything to stress_controller().
    parser = OptionParser()
    parser.add_option("--seed", dest="seed", metavar="FLOAT", default=None, type="float")
    parser.add_option("--table", dest="db_table", metavar="DB.TABLE", default=None, type="string")
    parser.add_option("--ops-per-sec", dest="ops_per_sec", metavar="NUMBER", default=100, type="float")
    parser.add_option("--ops-per-conn", dest="ops_per_conn", metavar="NUMBER", default=0, type="int")
    parser.add_option("--clients", dest="clients", metavar="CLIENTS", default=64, type="int")
    parser.add_option("--workload", "-w", dest="workload", metavar="WORKLOAD", default=None, type="string")
    parser.add_option("--host", dest="hosts", metavar="HOST:PORT", action="append", default=[], type="string")
    parser.add_option("--quiet", dest="quiet", action="store_true", default=False)
    parser.add_option("--duration", dest="duration", type="int", default=0)
    parser.add_option("--op-count", dest="op_count", type="int", default=0)
    parser.add_option("--ignore-errors", dest="ignore_errors", action="store_true", default=False)
    (parsed_options, args) = parser.parse_args()

    if args:
        print("no positional arguments supported", file=sys.stderr)
        sys.exit(1)

    options = {
        "clients": parsed_options.clients,
        "ops_per_sec": parsed_options.ops_per_sec,
        "ops_per_conn": parsed_options.ops_per_conn,
        "quiet": parsed_options.quiet,
        "hosts": [],
        "duration": parsed_options.duration,
        "op_count": parsed_options.op_count,
        "ignore_errors": parsed_options.ignore_errors,
        "seed": parsed_options.seed,
    }

    for host_port in parsed_options.hosts:
        # Split on the last colon and validate, so a malformed value produces
        # a clear message instead of a tuple-unpacking traceback.
        host, separator, port_text = host_port.rpartition(":")
        try:
            port = int(port_text) if separator else None
        except ValueError:
            port = None
        if port is None:
            print(f"invalid --host value, expected HOST:PORT: {host_port}", file=sys.stderr)
            sys.exit(1)
        options["hosts"].append((host, port))
    if not options["hosts"]:
        options["hosts"].append(("localhost", 28015))

    if parsed_options.workload is None:
        print("no workload specified", file=sys.stderr)
        sys.exit(1)

    if parsed_options.db_table is None:
        options["db"] = os.environ.get('DB_NAME', 'test')
        options["table"] = os.environ.get('TABLE_NAME', 'stress')
    else:
        db_and_table = parsed_options.db_table.split(".")
        if len(db_and_table) != 2:
            print(f"invalid --table value, expected DB.TABLE: {parsed_options.db_table}", file=sys.stderr)
            sys.exit(1)
        options["db"], options["table"] = db_and_table

    # Create the database and table on the first host if they are missing.
    with r.connect(options["hosts"][0][0], options["hosts"][0][1]) as connection:
        if options["db"] not in r.db_list().run(connection):
            r.db_create(options["db"]).run(connection)
        if options["table"] not in r.db(options["db"]).table_list().run(connection):
            r.db(options["db"]).table_create(options["table"]).run(connection)

    # Workload modules live in the stress_workloads directory next to this script.
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'stress_workloads')))
    workload_mod = __import__(parsed_options.workload)
    options["workload"] = workload_mod.Workload(options)

    stress_controller(options)
